package server

import (
	"epg/conf"
	"epg/jmconf"
	"fmt"
	"sync"

	"github.com/pingcap/tidb/mysql"

	"github.com/zeast/logs"
	"context"
	"time"
)

// Banlancer selects the backend that the next request is routed to, using
// weighted round-robin scheduling (sketched below).
//
// NOTE(review): DBPool.Balance in this file is spelled differently and returns
// *MysqlDB rather than *DBContext, so DBPool does not implement this interface.
type Banlancer interface {
	/*
		Weighted round-robin scheduling. S is the list of n servers, W(Si) the
		weight of server i, gcd(S)/max(S) the GCD/maximum of all weights, i the
		last examined index and cw the current weight threshold:
		while (true) {
		 i = (i + 1) mod n;
		 if (i == 0) {
		    cw = cw - gcd(S);
		    if (cw <= 0) {
		     cw = max(S);
		     if (cw == 0)
		       return NULL;
		    }
		 }
		if (W(Si) >= cw)
		   return Si;
		}
	*/
	// Banlance returns the DBContext chosen for the next request.
	Banlance() *DBContext
}

// DBContext is a connection handed out by a MysqlDB (see MysqlDB.GetConn),
// bundled with the MysqlDB it came from and the user it was handed to.
type DBContext struct {
	// The borrowed connection; Close sets it to nil.
	Conn
	// Owner of the connection; Close hands the connection back via db.putConn.
	db       *MysqlDB
	// Name of the user this connection was most recently handed to; Close
	// releases that user's conf.Limit slot for db.alias.
	lastUser string
}

// GetConn borrows a connection from m via m.getConn and wraps it in a
// DBContext attributed to user. timeCtx is passed straight through to
// m.getConn (presumably it bounds the wait — confirm in getConn). It returns
// nil when m.getConn yields no connection.
func (m *MysqlDB) GetConn(user string, timeCtx context.Context) *DBContext {
	if c := m.getConn(timeCtx); c != nil {
		return &DBContext{Conn: c, db: m, lastUser: user}
	}
	return nil
}

// addr returns the addr field of the MysqlDB this context was obtained from.
func (m *DBContext) addr() string {
	owner := m.db
	return owner.addr
}

// Close hands the underlying connection back to its MysqlDB (via putConn) and
// releases the conf.Limit slot held for lastUser on that node.
//
// When force is true the connection is flagged with mysql.ErrBadConn before
// it is handed back. Otherwise, if the server status reports an open
// transaction or autocommit is off, a ROLLBACK is sent first. An error from
// that ROLLBACK is logged and returned, and the connection is flagged bad as
// well so that it is not reused in an unknown transaction state.
//
// Close is idempotent: once Conn is nil, further calls return nil. Without
// this guard, a nil Conn fell into the "unknown type" branch, and the limit
// slot was released a second time.
func (m *DBContext) Close(force bool) (err error) {
	if m.Conn == nil {
		return nil
	}

	switch c := m.Conn.(type) {
	case *mysqlConn:
		if force {
			c.lastError = mysql.ErrBadConn
		} else if st := c.GetStatus(); st&mysql.ServerStatusInTrans > 0 || st&mysql.ServerStatusAutocommit == 0 {
			logs.Error("AUTO ROLLBACK")

			cmd := append([]byte{mysql.ComQuery}, "ROLLBACK;"...)
			if _, err = c.Exec(cmd); err != nil {
				logs.Errorf("Rollback Error:%s", err)
				// Same marker the force path uses; presumably putConn
				// discards connections carrying it — confirm in putConn.
				c.lastError = mysql.ErrBadConn
			}
		}
		m.db.putConn(c)

	default:
		logs.Errorf("未知的 db conn 类型, %+v", c)
	}

	conf.Limit.Put(m.lastUser, m.db.alias)

	m.Conn = nil
	return err
}

// DBPool holds the MySQL nodes (MysqlDB) serving one database name, split into
// a usable set and a "bad" set, plus the weighted round-robin state that
// Balance advances.
type DBPool struct {
	// Locked by Balance, put, pop, Release and the move* methods.
	sync.Mutex
	db     []*MysqlDB // usable nodes; the only ones Balance picks from
	baddb  []*MysqlDB // nodes moved out by moveOKDBBad (flagged bad)
	weight []int      // weight[k] is the weight of db[k]
	max    int        // largest value in weight
	i      int        // index last examined by Balance; newDBPool starts it at -1
	cw     int        // current weight threshold of the round-robin
	gcd    int        // greatest common divisor of weight
}

// Release closes every MysqlDB held by the pool. It first calls waitFree on
// each usable node and then closes all of them. Nodes parked in baddb are
// closed as well; delBadDB closes bad nodes the same way. Previously these
// nodes were skipped and never closed. It always returns nil.
func (p *DBPool) Release() error {
	p.Lock()
	defer p.Unlock()

	for _, db := range p.db {
		db.waitFree()
	}
	for _, db := range p.db {
		db.closeDB()
	}
	for _, db := range p.baddb {
		db.closeDB()
	}

	return nil
}

// String renders the usable nodes as index:alias pairs, followed by the
// round-robin state (weights, max, i, cw, gcd). Used for log output.
func (p *DBPool) String() string {
	out := fmt.Sprintf("DBPool:db:(%v):", len(p.db))
	for idx, node := range p.db {
		out += fmt.Sprintf(",%v:%v;", idx, node.alias)
	}
	out += fmt.Sprintf(",weight:%v,max:%v,i:%v,cw:%v,gcd:%v", p.weight, p.max, p.i, p.cw, p.gcd)
	return out
}

// newDBPool returns an empty pool with db and weight pre-sized for size
// entries. i starts at -1 so that Balance's first step lands on index 0.
func newDBPool(size int) *DBPool {
	pool := &DBPool{i: -1}
	pool.db = make([]*MysqlDB, 0, size)
	pool.weight = make([]int, 0, size)
	return pool
}

// length counts every node in the pool, usable and bad alike.
func (p *DBPool) length() int {
	total := len(p.db)
	total += len(p.baddb)
	return total
}

// pop closes and removes the node whose alias matches cfg.Alias. It looks in
// the usable set first and falls back to baddb. It runs under p's lock.
func (p *DBPool) pop(cfg conf.NodeConf) {
	p.Lock()
	defer p.Unlock()

	if p.delOKDB(cfg) {
		return
	}
	p.delBadDB(cfg)
}

// delOKDB closes and removes the usable node whose alias matches cfg.Alias,
// dropping its weight and recomputing max and gcd. It reports whether such a
// node was found. pop calls it with p's lock held.
func (p *DBPool) delOKDB(cfg conf.NodeConf) bool {
	pos := p.findOKDB(cfg)
	if pos < 0 {
		return false
	}

	p.db[pos].closeDB()
	p.db = append(p.db[:pos], p.db[pos+1:]...)
	p.weight = append(p.weight[:pos], p.weight[pos+1:]...)
	p.max = max(p.weight)
	p.gcd = gcd(p.weight)
	return true
}

// delBadDB closes and removes the bad node whose alias matches cfg.Alias and
// reports whether one was found. pop calls it with p's lock held.
func (p *DBPool) delBadDB(cfg conf.NodeConf) bool {
	pos := p.findBadDB(cfg)
	if pos < 0 {
		return false
	}

	p.baddb[pos].closeDB()
	p.baddb = append(p.baddb[:pos], p.baddb[pos+1:]...)
	return true
}

// moveOKDBBad moves the usable node whose alias matches cfg.Alias into baddb.
// It drops the node's weight, recomputes max and gcd, sets the node's bad
// flag, and starts node.reConnectCheck in a new goroutine. Both p's lock and
// the node's mu are held throughout. It does nothing if no usable node matches.
func (p *DBPool) moveOKDBBad(cfg conf.NodeConf) {
	p.Lock()
	defer p.Unlock()

	pos := p.findOKDB(cfg)
	if pos < 0 {
		return
	}

	node := p.db[pos]
	node.mu.Lock()
	defer node.mu.Unlock()

	p.db = append(p.db[:pos], p.db[pos+1:]...)
	p.weight = append(p.weight[:pos], p.weight[pos+1:]...)
	p.max = max(p.weight)
	p.gcd = gcd(p.weight)

	p.baddb = append(p.baddb, node)
	node.bad = true
	go node.reConnectCheck()
}

// moveBadDBOK moves the bad node whose alias matches cfg.Alias back into the
// usable set with weight cfg.Weight. It clears the node's bad flag and
// recomputes max and gcd, holding both p's lock and the node's mu. It does
// nothing if no bad node matches.
func (p *DBPool) moveBadDBOK(cfg conf.NodeConf) {
	p.Lock()
	defer p.Unlock()

	pos := p.findBadDB(cfg)
	if pos < 0 {
		return
	}

	node := p.baddb[pos]
	node.mu.Lock()
	defer node.mu.Unlock()

	node.bad = false
	p.baddb = append(p.baddb[:pos], p.baddb[pos+1:]...)

	p.db = append(p.db, node)
	p.weight = append(p.weight, cfg.Weight)
	p.max = max(p.weight)
	p.gcd = gcd(p.weight)
}

// findOKDB returns the index in p.db of the first node whose alias equals
// cfg.Alias, or -1 when there is none. It takes no lock itself.
func (p *DBPool) findOKDB(cfg conf.NodeConf) int {
	for pos := range p.db {
		if p.db[pos].alias == cfg.Alias {
			return pos
		}
	}
	return -1
}

// findBadDB returns the index in p.baddb of the first node whose alias equals
// cfg.Alias, or -1 when there is none. It takes no lock itself.
func (p *DBPool) findBadDB(cfg conf.NodeConf) int {
	for pos := range p.baddb {
		if p.baddb[pos].alias == cfg.Alias {
			return pos
		}
	}
	return -1
}

// put appends db to the usable set with the given weight and refreshes max
// and gcd. Non-positive weights are ignored, so the node is not added.
func (p *DBPool) put(db *MysqlDB, weight int) {
	if p.db == nil {
		p.db = []*MysqlDB{}
	}
	if weight > 0 {
		p.Lock()
		defer p.Unlock()

		p.db = append(p.db, db)
		p.weight = append(p.weight, weight)
		if weight > p.max {
			p.max = weight
		}
		p.gcd = gcd(p.weight)
	}
}

// Balance picks the next usable MysqlDB using weighted round-robin:
//   - i walks the usable nodes cyclically.
//   - Each time i wraps to 0, the threshold cw drops by gcd; once it reaches
//     0 or below, it resets to max.
//   - The first node whose weight is >= cw is returned.
//
// It returns nil when there is no usable node (or max is 0). All reads and
// writes of the pool happen under p's lock.
func (p *DBPool) Balance() *MysqlDB {
	p.Lock()
	defer p.Unlock()

	n := len(p.db)
	switch n {
	case 0:
		// baddb may still be non-empty here (moveOKDBBad and pop can empty the
		// usable set), so checking length() is not enough: the modulo below
		// would panic with an integer divide by zero.
		return nil
	case 1:
		// Nothing to balance.
		return p.db[0]
	}

	for {
		p.i = (p.i + 1) % n
		if p.i == 0 {
			p.cw -= p.gcd
			if p.cw <= 0 {
				p.cw = p.max
				if p.cw == 0 {
					return nil
				}
			}
		}
		if p.weight[p.i] >= p.cw {
			return p.db[p.i]
		}
	}
}

// gcd returns the greatest common divisor of the values in arr, or 0 when arr
// is empty. Balance subtracts it from the round-robin threshold on every wrap.
//
// It uses Euclid's algorithm instead of counting down from the minimum, which
// took O(min(arr)*len(arr)) steps. Values are taken by magnitude, and a zero
// contributes nothing (gcd(0, x) == x).
func gcd(arr []int) int {
	g := 0
	for _, v := range arr {
		if v < 0 {
			v = -v
		}
		for v != 0 {
			g, v = v, g%v
		}
	}
	return g
}

//NodeMgr The database node manager.
var NodeMgr = &nodeMgr{
	pools: make(map[string]*DBPool),
	cfgs:  make(map[string]conf.NodeConf),
}

// nodeMgr maps each database name to the DBPool of MySQL nodes serving it and
// keeps every node's configuration by alias. AddNode, RemoveNode, GetConn and
// the move* methods lock the embedded RWMutex around access to these maps.
type nodeMgr struct {
	sync.RWMutex
	pools map[string]*DBPool       // database name => pool of nodes serving it; may hold several replicas
	cfgs  map[string]conf.NodeConf // node alias => its NodeConf
}

// ReleaseAll releases every DBPool (see DBPool.Release), stopping at and
// returning the first error. Each pool calls waitFree on its nodes, so this
// should only be invoked once in-flight transactions are done.
//
// The pools are snapshotted under the read lock so that the map is never
// iterated while AddNode/RemoveNode write it. Concurrent map iteration and
// write is a fatal runtime error. The lock is not held while releasing,
// because waitFree may block for a long time.
func (n *nodeMgr) ReleaseAll() error {
	n.RLock()
	pools := make([]*DBPool, 0, len(n.pools))
	for _, p := range n.pools {
		pools = append(pools, p)
	}
	n.RUnlock()

	for _, p := range pools {
		logs.Infof("Release the DBPool: %s", p.String())
		if err := p.Release(); err != nil {
			return err
		}
	}
	return nil
}

// AddNode registers the node described by cfg in every pool listed in
// cfg.Dbs. It then sets the node's connection cap in conf.Limit to
// cfg.MaxConnNum. Both steps run under n's write lock.
func (n *nodeMgr) AddNode(cfg conf.NodeConf) {
	n.Lock()
	defer n.Unlock()

	n.AddNodeNoLock(cfg)
	conf.Limit.ChangeNodeConf(cfg.Alias, cfg.MaxConnNum)
}

// AddNodeNoLock opens a MysqlDB for cfg and adds it, with cfg.Weight, to the
// pool of every database in cfg.Dbs, creating pools as needed. It then
// records cfg under cfg.Alias. The caller must hold n's write lock.
//
// Open is skipped when no pool would accept the node: DBPool.put rejects
// cfg.Weight <= 0, and an empty cfg.Dbs reaches no pool at all. Previously
// the MysqlDB returned by Open was then dropped without ever being closed.
func (n *nodeMgr) AddNodeNoLock(cfg conf.NodeConf) {
	var db *MysqlDB
	if cfg.Weight > 0 && len(cfg.Dbs) > 0 {
		db = Open(&cfg)
	}

	for _, name := range cfg.Dbs {
		p, ok := n.pools[name]
		if !ok {
			p = newDBPool(len(cfg.Dbs))
		}
		if db != nil {
			p.put(db, cfg.Weight)
		}
		n.pools[name] = p
	}
	n.cfgs[cfg.Alias] = cfg
	logs.Infof("NodeMgr--- Node Add: %#v", cfg)
}

// RemoveNode removes the node described by cfg from every pool listed in
// cfg.Dbs. It then sets the node's connection cap in conf.Limit to 0. Both
// steps run under n's write lock.
func (n *nodeMgr) RemoveNode(cfg conf.NodeConf) {
	n.Lock()
	defer n.Unlock()

	n.RemoveNodeNoLock(cfg)
	conf.Limit.ChangeNodeConf(cfg.Alias, 0)
}

// RemoveNodeNoLock pops (closes and removes) the node described by cfg from
// each pool in cfg.Dbs and deletes pools left without any node. It also drops
// cfg from the config map. The caller must hold n's write lock.
func (n *nodeMgr) RemoveNodeNoLock(cfg conf.NodeConf) {
	for _, name := range cfg.Dbs {
		p, ok := n.pools[name]
		if !ok {
			continue
		}
		p.pop(cfg)
		// p is a pointer already stored in the map, so only an emptied pool
		// needs a map update: it is deleted.
		if p.length() == 0 {
			delete(n.pools, name)
		}
	}
	delete(n.cfgs, cfg.Alias)
	logs.Infof("NodeMgr--- Node Remove: %#v", cfg)
}

// moveOKDBBad evicts the node described by cfg from the usable set of every
// pool in cfg.Dbs (see DBPool.moveOKDBBad). It does nothing when
// jmconf.Cfg.Role is "master".
//
// At least one usable node is always kept per pool. The previous guard used
// p.length(), which also counts nodes already in baddb. With one usable and
// one bad node, the last usable node could therefore be evicted, leaving
// Balance nothing to pick. len(p.db) is safe to read here: every mutation of
// p.db in this file happens under n's write lock, which is held.
func (n *nodeMgr) moveOKDBBad(cfg conf.NodeConf) {
	if jmconf.Cfg.Role == "master" {
		return
	}
	n.Lock()
	defer n.Unlock()

	for _, name := range cfg.Dbs {
		if p, ok := n.pools[name]; ok && len(p.db) > 1 {
			p.moveOKDBBad(cfg)
		}
	}
}

// moveBadDBOK returns the node described by cfg to the usable set of every
// pool in cfg.Dbs (see DBPool.moveBadDBOK), under n's write lock. It does
// nothing when jmconf.Cfg.Role is "master".
func (n *nodeMgr) moveBadDBOK(cfg conf.NodeConf) {
	if role := jmconf.Cfg.Role; role == "master" {
		return
	}

	n.Lock()
	defer n.Unlock()

	for _, name := range cfg.Dbs {
		p, ok := n.pools[name]
		if !ok {
			continue
		}
		p.moveBadDBOK(cfg)
	}
}

// ModifyNode replaces the node registered under cfg.Alias with cfg, or adds it
// if no node has that alias. The node's conf.Limit cap is then set to
// cfg.MaxConnNum.
//
// The whole lookup-remove-add sequence runs under n's write lock.
// Previously n.cfgs was read unlocked, which raced with AddNode/RemoveNode and
// let two concurrent calls both take the "add" path.
func (n *nodeMgr) ModifyNode(cfg conf.NodeConf) {
	logs.Info("NodeMgr--- Node Modify")

	n.Lock()
	defer n.Unlock()

	if oldCfg, ok := n.cfgs[cfg.Alias]; ok {
		n.RemoveNodeNoLock(oldCfg)
	} else {
		logs.Infof("NodeModify add, alias:%s", cfg.Alias)
	}
	n.AddNodeNoLock(cfg)
	conf.Limit.ChangeNodeConf(cfg.Alias, cfg.MaxConnNum)
}

// getMysqlDB returns the usable MysqlDB whose alias is cfg.Alias in the pool
// of cfg.Dbs[0], or nil if cfg has no alias or no databases, the pool does
// not exist, or no usable node matches. Runs under n's read lock.
func (n *nodeMgr) getMysqlDB(cfg *conf.NodeConf) *MysqlDB {
	if cfg.Alias == "" || len(cfg.Dbs) == 0 {
		return nil
	}

	n.RLock()
	defer n.RUnlock()

	pool, found := n.pools[cfg.Dbs[0]]
	if !found {
		return nil
	}
	for _, node := range pool.db {
		if node.alias == cfg.Alias {
			return node
		}
	}
	return nil
}

// GetConn hands user a connection to one of the nodes serving dbname.
//
// Under n's read lock it picks a node with DBPool.Balance and requests a
// per-user slot from conf.Limit. It then waits up to
// jmconf.Cfg.GetMysqlConnTimeout milliseconds for that slot. With the slot
// granted, it borrows a connection from the node using the same deadline;
// the slot is returned if no connection can be obtained. Failures are
// reported as mysql errors.
func (n *nodeMgr) GetConn(user, dbname string) (*DBContext, error) {
	n.RLock()
	defer n.RUnlock()

	pool, found := n.pools[dbname]
	if !found {
		return nil, mysql.NewErr(mysql.ErrWrongDBName, dbname)
	}

	node := pool.Balance()
	if node == nil {
		logs.Errorf("不能通过均衡算法找到合适的 DB, %#v", pool)
		return nil, mysql.NewErr(mysql.ErrUnknown, "Balance can not find DBPool")
	}

	slot, err := conf.Limit.Get(user, node.alias)
	if err != nil {
		return nil, mysql.NewErr(mysql.ErrUnknown, err.Error())
	}

	wait := time.Duration(jmconf.Cfg.GetMysqlConnTimeout) * time.Millisecond
	timeCtx, cancel := context.WithTimeout(context.Background(), wait)
	defer cancel()

	select {
	case <-timeCtx.Done():
		return nil, mysql.NewErrf(mysql.ErrUnknown, "用户到 MySQL 连接数受限 : %s - %s", user, node.alias)

	case <-slot:
		dbCtx := node.GetConn(user, timeCtx)
		if dbCtx == nil {
			conf.Limit.Put(user, node.alias)
			return nil, mysql.NewErrf(mysql.ErrUnknown, "没有空闲的 MySQL 连接 :%s - %s", user, node.alias)
		}
		return dbCtx, nil
	}
}

// newNodeMgr returns an empty node manager with both maps allocated. It does
// not touch the package-level NodeMgr.
func newNodeMgr() *nodeMgr {
	mgr := new(nodeMgr)
	mgr.pools = make(map[string]*DBPool)
	mgr.cfgs = make(map[string]conf.NodeConf)
	return mgr
}

// max returns the largest element of s, or 0 when s is empty or has no
// positive element. Within this package it shadows the Go 1.21 builtin max.
func max(s []int) int {
	largest := 0
	for _, v := range s {
		if v > largest {
			largest = v
		}
	}
	return largest
}
